Context

Cardiovascular diseases (CVDs) are the number one cause of death globally, taking an estimated 17.9 million lives each year, which accounts for 31% of all deaths worldwide. Four out of five CVD deaths are due to heart attacks and strokes, and one-third of these deaths occur prematurely in people under 70 years of age. Heart failure is a common consequence of CVDs, and this dataset contains 11 features that can be used to predict possible heart disease.

People with cardiovascular disease, or who are at high cardiovascular risk (due to one or more risk factors such as hypertension, diabetes, hyperlipidaemia, or already established disease), need early detection and management, and a machine learning model can be of great help here.

Attribute Information

1) Age: age of the patient [years]

2) Sex: sex of the patient [M: Male, F: Female]

3) ChestPainType: chest pain type [TA: Typical Angina, ATA: Atypical Angina, NAP: Non-Anginal Pain, ASY: Asymptomatic]

4) RestingBP: resting blood pressure [mm Hg]

5) Cholesterol: serum cholesterol [mg/dl]

6) FastingBS: fasting blood sugar [1: if FastingBS > 120 mg/dl, 0: otherwise]

7) RestingECG: resting electrocardiogram results [Normal: Normal, ST: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV), LVH: showing probable or definite left ventricular hypertrophy by Estes' criteria]

8) MaxHR: maximum heart rate achieved [Numeric value between 60 and 202]

9) ExerciseAngina: exercise-induced angina [Y: Yes, N: No]

10) Oldpeak: ST depression induced by exercise relative to rest [Numeric value]

11) ST_Slope: the slope of the peak exercise ST segment [Up: upsloping, Flat: flat, Down: downsloping]

12) HeartDisease: output class [1: heart disease, 0: Normal]
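The categorical codes and numeric ranges listed above double as a lightweight validation contract. A minimal sketch of such a check, run here on two sample rows mirroring the head of the dataset (in the notebook it would run on the loaded `df`):

```python
import pandas as pd

# Two sample rows mirroring the first rows of the dataset.
sample = pd.DataFrame({
    "Age": [40, 49], "Sex": ["M", "F"], "ChestPainType": ["ATA", "NAP"],
    "RestingBP": [140, 160], "Cholesterol": [289, 180], "FastingBS": [0, 0],
    "RestingECG": ["Normal", "Normal"], "MaxHR": [172, 156],
    "ExerciseAngina": ["N", "N"], "Oldpeak": [0.0, 1.0],
    "ST_Slope": ["Up", "Flat"], "HeartDisease": [0, 1],
})

# Allowed categories and ranges taken from the attribute list above.
assert set(sample["Sex"]) <= {"M", "F"}
assert set(sample["ChestPainType"]) <= {"TA", "ATA", "NAP", "ASY"}
assert set(sample["RestingECG"]) <= {"Normal", "ST", "LVH"}
assert set(sample["ST_Slope"]) <= {"Up", "Flat", "Down"}
assert sample["MaxHR"].between(60, 202).all()
assert set(sample["HeartDisease"]) <= {0, 1}
print("schema OK")
```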

Source

This dataset was created by combining different datasets already available independently but not combined before. In this dataset, 5 heart datasets are combined over 11 common features, which makes it the largest heart disease dataset available so far for research purposes. The five datasets used for its curation are:

Cleveland: 303 observations
Hungarian: 294 observations
Switzerland: 123 observations
Long Beach VA: 200 observations
Statlog (Heart) Data Set: 270 observations
Total: 1190 observations
Duplicated: 272 observations

Final dataset: 918 observations
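The duplicate-removal step implied by these counts can be sketched with pandas. The toy frames and overlapping row below are illustrative only, not the actual curation script:

```python
import pandas as pd

# Two toy source frames standing in for e.g. Cleveland and Statlog;
# the overlapping row simulates a record present in both sources.
a = pd.DataFrame({"Age": [63, 41], "Sex": ["M", "F"], "HeartDisease": [1, 0]})
b = pd.DataFrame({"Age": [41, 57], "Sex": ["F", "M"], "HeartDisease": [0, 1]})

combined = pd.concat([a, b], ignore_index=True)        # 4 rows
deduped = combined.drop_duplicates(ignore_index=True)  # 3 rows: duplicate dropped
print(len(combined), len(deduped))  # 4 3
```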

In [1]:
import os, types
import pandas as pd
from botocore.client import Config
import ibm_boto3

def __iter__(self): return 0


# client_861eae432595491682c91ae0bd54b5e1 is the Cloud Object Storage client
# generated by the Watson Studio "Insert to code" helper (credentials cell not shown).
body = client_861eae432595491682c91ae0bd54b5e1.get_object(Bucket='heartfailureprediction-donotdelete-pr-ortd6bqfoibhau', Key='heart.csv')['Body']
# add the missing __iter__ method so pandas accepts body as a file-like object
if not hasattr(body, "__iter__"): body.__iter__ = types.MethodType(__iter__, body)

df = pd.read_csv(body)
df.head()
Out[1]:
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
0 40 M ATA 140 289 0 Normal 172 N 0.0 Up 0
1 49 F NAP 160 180 0 Normal 156 N 1.0 Flat 1
2 37 M ATA 130 283 0 ST 98 N 0.0 Up 0
3 48 F ASY 138 214 0 Normal 108 Y 1.5 Flat 1
4 54 M NAP 150 195 0 Normal 122 N 0.0 Up 0
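The cell above relies on the Watson Studio object-storage boilerplate. Outside that environment the same file loads with a plain `pd.read_csv`; a minimal sketch, with a `StringIO` buffer standing in for a local `heart.csv`:

```python
import io
import pandas as pd

# A StringIO buffer stands in for a local "heart.csv" file; with a real file
# this would simply be pd.read_csv("heart.csv").
csv_text = (
    "Age,Sex,ChestPainType,RestingBP,Cholesterol,FastingBS,RestingECG,"
    "MaxHR,ExerciseAngina,Oldpeak,ST_Slope,HeartDisease\n"
    "40,M,ATA,140,289,0,Normal,172,N,0.0,Up,0\n"
    "49,F,NAP,160,180,0,Normal,156,N,1.0,Flat,1\n"
)
df_local = pd.read_csv(io.StringIO(csv_text))
print(df_local.shape)  # (2, 12)
```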

Using Pandas-Profiling for EDA

In [2]:
import sys 
!{sys.executable} -m pip install pandas-profiling
Collecting pandas-profiling
  Downloading pandas_profiling-3.2.0-py2.py3-none-any.whl (262 kB)
... (dependency resolution output trimmed) ...
Successfully built htmlmin imagehash
Successfully installed htmlmin-0.1.12 imagehash-4.2.1 joblib-1.1.0 markupsafe-2.1.1 missingno-0.5.1 multimethod-1.8 pandas-profiling-3.2.0 phik-0.12.2 pydantic-1.9.0 tangled-up-in-unicode-0.2.0 visions-0.7.4
In [3]:
from pandas_profiling import ProfileReport
profile = ProfileReport(df, title="Pandas Profiling Report")
profile
Out[3]:

Convert Categorical features to numerical values

Sex:

In [4]:
df['Sex'].replace(to_replace=['M','F'], value=[0,1],inplace=True)
df.head()
Out[4]:
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
0 40 0 ATA 140 289 0 Normal 172 N 0.0 Up 0
1 49 1 NAP 160 180 0 Normal 156 N 1.0 Flat 1
2 37 0 ATA 130 283 0 ST 98 N 0.0 Up 0
3 48 1 ASY 138 214 0 Normal 108 Y 1.5 Flat 1
4 54 0 NAP 150 195 0 Normal 122 N 0.0 Up 0
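An equivalent, arguably more explicit encoding uses `Series.map`, which gives the same result as the `replace` call above but turns any unexpected label into `NaN` instead of passing it through silently:

```python
import pandas as pd

sex = pd.Series(["M", "F", "M"])
# map makes the encoding explicit; labels outside the mapping become NaN,
# which surfaces data problems early.
encoded = sex.map({"M": 0, "F": 1})
print(encoded.tolist())  # [0, 1, 0]
```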

ChestPainType:

In [5]:
ChestPainType_dummy = pd.get_dummies(df['ChestPainType'])
ChestPainType_dummy.rename(columns={'TA':'ChestPainType-TA','ATA':'ChestPainType-ATA','NAP':'ChestPainType-NAP','ASY':'ChestPainType-ASY'}, inplace=True)
df = pd.concat([df,ChestPainType_dummy],axis=1)
df.drop('ChestPainType',axis=1,inplace=True)
df.head()
Out[5]:
Age Sex RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease ChestPainType-ASY ChestPainType-ATA ChestPainType-NAP ChestPainType-TA
0 40 0 140 289 0 Normal 172 N 0.0 Up 0 0 1 0 0
1 49 1 160 180 0 Normal 156 N 1.0 Flat 1 0 0 1 0
2 37 0 130 283 0 ST 98 N 0.0 Up 0 0 1 0 0
3 48 1 138 214 0 Normal 108 Y 1.5 Flat 1 1 0 0 0
4 54 0 150 195 0 Normal 122 N 0.0 Up 0 0 0 1 0

RestingECG:

In [6]:
RestingECG_dummy = pd.get_dummies(df['RestingECG'])
RestingECG_dummy.rename(columns={'Normal':'RestingECG-Normal','ST':'RestingECG-ST','LVH':'RestingECG-LVH'}, inplace=True)
df = pd.concat([df,RestingECG_dummy],axis=1)
df.drop('RestingECG',axis=1,inplace=True)
df.head()
Out[6]:
Age Sex RestingBP Cholesterol FastingBS MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease ChestPainType-ASY ChestPainType-ATA ChestPainType-NAP ChestPainType-TA RestingECG-LVH RestingECG-Normal RestingECG-ST
0 40 0 140 289 0 172 N 0.0 Up 0 0 1 0 0 0 1 0
1 49 1 160 180 0 156 N 1.0 Flat 1 0 0 1 0 0 1 0
2 37 0 130 283 0 98 N 0.0 Up 0 0 1 0 0 0 0 1
3 48 1 138 214 0 108 Y 1.5 Flat 1 1 0 0 0 0 1 0
4 54 0 150 195 0 122 N 0.0 Up 0 0 0 1 0 0 1 0

ExerciseAngina:

In [7]:
df['ExerciseAngina'].replace(to_replace=['N','Y'], value=[0,1],inplace=True)
df.head()
Out[7]:
Age Sex RestingBP Cholesterol FastingBS MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease ChestPainType-ASY ChestPainType-ATA ChestPainType-NAP ChestPainType-TA RestingECG-LVH RestingECG-Normal RestingECG-ST
0 40 0 140 289 0 172 0 0.0 Up 0 0 1 0 0 0 1 0
1 49 1 160 180 0 156 0 1.0 Flat 1 0 0 1 0 0 1 0
2 37 0 130 283 0 98 0 0.0 Up 0 0 1 0 0 0 0 1
3 48 1 138 214 0 108 1 1.5 Flat 1 1 0 0 0 0 1 0
4 54 0 150 195 0 122 0 0.0 Up 0 0 0 1 0 0 1 0

ST_Slope:

In [8]:
ST_Slope_dummy = pd.get_dummies(df['ST_Slope'])
ST_Slope_dummy.rename(columns={'Up':'ST_Slope-Up','Flat':'ST_Slope-Flat','Down':'ST_Slope-Down'}, inplace=True)
df = pd.concat([df,ST_Slope_dummy],axis=1)
df.drop('ST_Slope',axis=1,inplace=True)
df.head()
Out[8]:
Age Sex RestingBP Cholesterol FastingBS MaxHR ExerciseAngina Oldpeak HeartDisease ChestPainType-ASY ChestPainType-ATA ChestPainType-NAP ChestPainType-TA RestingECG-LVH RestingECG-Normal RestingECG-ST ST_Slope-Down ST_Slope-Flat ST_Slope-Up
0 40 0 140 289 0 172 0 0.0 0 0 1 0 0 0 1 0 0 0 1
1 49 1 160 180 0 156 0 1.0 1 0 0 1 0 0 1 0 0 1 0
2 37 0 130 283 0 98 0 0.0 0 0 1 0 0 0 0 1 0 0 1
3 48 1 138 214 0 108 1 1.5 1 1 0 0 0 0 1 0 0 1 0
4 54 0 150 195 0 122 0 0.0 0 0 0 1 0 0 1 0 0 0 1
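The three one-hot blocks above (ChestPainType, RestingECG, ST_Slope) can also be done in a single call: `pd.get_dummies` accepts a `columns` list, and `prefix_sep="-"` reproduces the same hyphenated column names as the manual renames. A sketch on a toy frame:

```python
import pandas as pd

toy = pd.DataFrame({
    "ChestPainType": ["ATA", "NAP"],
    "RestingECG": ["Normal", "ST"],
    "ST_Slope": ["Up", "Flat"],
    "Age": [40, 49],
})
# One call replaces the three concat/rename/drop blocks; prefix_sep="-"
# yields names like "ChestPainType-ATA".
out = pd.get_dummies(toy, columns=["ChestPainType", "RestingECG", "ST_Slope"],
                     prefix_sep="-")
print(sorted(out.columns))
```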
In [21]:
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

Confusion Matrix Function.

In [10]:
def plot_confusion_matrix(y, y_predict):
    """Plot a labeled confusion matrix for true vs. predicted classes."""
    from sklearn.metrics import confusion_matrix

    cm = confusion_matrix(y, y_predict)
    ax = plt.subplot()
    sns.heatmap(cm, annot=True, ax=ax)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')
    ax.xaxis.set_ticklabels(['no heart disease', 'heart disease'])
    ax.yaxis.set_ticklabels(['no heart disease', 'heart disease'])

Feature Selection

Let's define feature sets, X:

In [11]:
X = df[['Age','Sex','RestingBP','Cholesterol','FastingBS','MaxHR','ExerciseAngina','Oldpeak','ChestPainType-ASY','ChestPainType-ATA','ChestPainType-NAP','ChestPainType-TA','RestingECG-LVH','RestingECG-Normal','RestingECG-ST','ST_Slope-Down','ST_Slope-Flat','ST_Slope-Up']]
X.head()
Out[11]:
Age Sex RestingBP Cholesterol FastingBS MaxHR ExerciseAngina Oldpeak ChestPainType-ASY ChestPainType-ATA ChestPainType-NAP ChestPainType-TA RestingECG-LVH RestingECG-Normal RestingECG-ST ST_Slope-Down ST_Slope-Flat ST_Slope-Up
0 40 0 140 289 0 172 0 0.0 0 1 0 0 0 1 0 0 0 1
1 49 1 160 180 0 156 0 1.0 0 0 1 0 0 1 0 0 1 0
2 37 0 130 283 0 98 0 0.0 0 1 0 0 0 0 1 0 0 1
3 48 1 138 214 0 108 1 1.5 1 0 0 0 0 1 0 0 1 0
4 54 0 150 195 0 122 0 0.0 0 0 1 0 0 1 0 0 0 1

What are our labels?

In [12]:
Y = df['HeartDisease'].values
Y[0:5]
Out[12]:
array([0, 1, 0, 1, 0])

Normalize Data

In [13]:
transform = preprocessing.StandardScaler()
X = transform.fit_transform(X)

Model Development

Train/Test split:

In [14]:
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size = 0.2, random_state = 2)
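One caveat about the normalization cell above: fitting the scaler on the full dataset lets test-set statistics leak into training. A leakage-free variant fits the scaler on the training split only and reuses its statistics on the test split; a sketch with synthetic data:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
y = rng.integers(0, 2, size=100)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=2)

scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)  # statistics from training data only
X_test_s = scaler.transform(X_test)        # reuse those statistics on the test set
print(X_train_s.mean(axis=0).round(6))     # ~0 for each training column
```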

Logistic Regression

In [15]:
parameters ={'C':[0.01,0.1,1],
             'penalty':['l2'],
             'solver':['lbfgs']}

lr=LogisticRegression()
In [16]:
logreg_cv = GridSearchCV(lr,parameters,cv=10)
logreg_cv.fit(X_train, Y_train)
Out[16]:
GridSearchCV(cv=10, estimator=LogisticRegression(),
             param_grid={'C': [0.01, 0.1, 1], 'penalty': ['l2'],
                         'solver': ['lbfgs']})
In [17]:
print("tuned hyperparameters (best parameters):", logreg_cv.best_params_)
print("accuracy :", logreg_cv.best_score_)
tuned hyperparameters (best parameters): {'C': 0.1, 'penalty': 'l2', 'solver': 'lbfgs'}
accuracy : 0.8705479452054794
In [18]:
print('Accuracy is ', logreg_cv.score(X_test,Y_test))
Accuracy is  0.8369565217391305

Confusion Matrix:

In [19]:
yhat=logreg_cv.predict(X_test)
plot_confusion_matrix(Y_test,yhat)
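Accuracy and the confusion matrix can be complemented with per-class precision and recall. A sketch with illustrative labels (in the notebook, `Y_test` and `yhat` would be substituted):

```python
from sklearn.metrics import classification_report, f1_score

# Illustrative labels only; substitute Y_test and yhat in the notebook.
y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 1]
print(classification_report(y_true, y_pred,
                            target_names=["no heart disease", "heart disease"]))
print("F1:", round(f1_score(y_true, y_pred), 3))  # F1: 0.8
```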

Support Vector Machine (SVM)

In [22]:
parameters = {'kernel': ('linear', 'rbf', 'poly', 'sigmoid'),
              'C': np.logspace(-3, 3, 5),
              'gamma': np.logspace(-3, 3, 5)}
svm = SVC()
In [23]:
svm_cv = GridSearchCV(svm, parameters, cv=10)
svm_cv.fit(X_train, Y_train)
Out[23]:
GridSearchCV(cv=10, estimator=SVC(),
             param_grid={'C': array([1.00000000e-03, 3.16227766e-02, 1.00000000e+00, 3.16227766e+01,
       1.00000000e+03]),
                         'gamma': array([1.00000000e-03, 3.16227766e-02, 1.00000000e+00, 3.16227766e+01,
       1.00000000e+03]),
                         'kernel': ('linear', 'rbf', 'poly', 'sigmoid')})
In [24]:
print("tuned hyperparameters (best parameters):", svm_cv.best_params_)
print("accuracy :", svm_cv.best_score_)
tuned hyperparameters (best parameters): {'C': 1000.0, 'gamma': 0.001, 'kernel': 'rbf'}
accuracy : 0.8774157719363197
In [25]:
print('Accuracy is', svm_cv.score(X_test, Y_test))
Accuracy is 0.8315217391304348

Confusion Matrix:

In [26]:
yhat=svm_cv.predict(X_test)
plot_confusion_matrix(Y_test,yhat)

Decision Tree

In [27]:
parameters = {'criterion': ['gini', 'entropy'],
     'splitter': ['best', 'random'],
     'max_depth': [2*n for n in range(1,10)],
     'max_features': ['auto', 'sqrt'],
     'min_samples_leaf': [1, 2, 4],
     'min_samples_split': [2, 5, 10]}

tree = DecisionTreeClassifier()
In [28]:
tree_cv = GridSearchCV(tree, parameters, cv=10)
tree_cv.fit(X_train,Y_train)
Out[28]:
GridSearchCV(cv=10, estimator=DecisionTreeClassifier(),
             param_grid={'criterion': ['gini', 'entropy'],
                         'max_depth': [2, 4, 6, 8, 10, 12, 14, 16, 18],
                         'max_features': ['auto', 'sqrt'],
                         'min_samples_leaf': [1, 2, 4],
                         'min_samples_split': [2, 5, 10],
                         'splitter': ['best', 'random']})
In [29]:
print("tuned hyperparameters (best parameters):", tree_cv.best_params_)
print("accuracy :", tree_cv.best_score_)
tuned hyperparameters (best parameters): {'criterion': 'entropy', 'max_depth': 6, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 10, 'splitter': 'best'}
accuracy : 0.8596815994076268
In [30]:
print('Accuracy is', tree_cv.score(X_test,Y_test))
Accuracy is 0.8206521739130435

Confusion Matrix:

In [31]:
yhat = tree_cv.predict(X_test)
plot_confusion_matrix(Y_test,yhat)

K-Nearest Neighbors (KNN)

In [32]:
parameters = {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],
              'p': [1,2]}

KNN = KNeighborsClassifier()
In [33]:
knn_cv = GridSearchCV(KNN, parameters, cv=10)
knn_cv.fit(X_train, Y_train)
Out[33]:
GridSearchCV(cv=10, estimator=KNeighborsClassifier(),
             param_grid={'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],
                         'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                         'p': [1, 2]})
In [34]:
print("tuned hyperparameters (best parameters):", knn_cv.best_params_)
print("accuracy :", knn_cv.best_score_)
tuned hyperparameters (best parameters): {'algorithm': 'auto', 'n_neighbors': 10, 'p': 1}
accuracy : 0.8705294335431321
In [35]:
print('Accuracy is', knn_cv.score(X_test,Y_test))
Accuracy is 0.8532608695652174

Confusion Matrix:

In [36]:
yhat = knn_cv.predict(X_test)
plot_confusion_matrix(Y_test,yhat)

Finding Best Model and Accuracy

In [37]:
models = {'kneighbors': knn_cv.best_score_,
         'DecisionTree': tree_cv.best_score_,
         'SVM': svm_cv.best_score_,
         'LogisticRegression': logreg_cv.best_score_ }

best_model = max(models, key = models.get)
print('The best model is',best_model, 'with a score of',models[best_model])
The best model is SVM with a score of 0.8774157719363197
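Note that `best_score_` is the cross-validation score on the training split; ranking by held-out test accuracy can disagree with it, as it does here. A sketch of the comparison, with scores copied (rounded) from the cells above:

```python
# CV scores (best_score_) and held-out test scores (.score(X_test, Y_test)),
# rounded from the cells above.
cv_scores = {"LogisticRegression": 0.8705, "SVM": 0.8774,
             "DecisionTree": 0.8597, "kneighbors": 0.8705}
test_scores = {"LogisticRegression": 0.8370, "SVM": 0.8315,
               "DecisionTree": 0.8207, "kneighbors": 0.8533}

best_cv = max(cv_scores, key=cv_scores.get)
best_test = max(test_scores, key=test_scores.get)
print("best by CV:", best_cv)      # SVM
print("best by test:", best_test)  # kneighbors
```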
In [39]:
profile.to_file("report.html")